{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 语言建模(Language Modeling)\n",
    "语言建模中我们可以判断一句话是否接近人说的话，一般给出前面的词来预测后面的词，而语言建模的数据集是大量的语料(corpus)，将语料数据进行预处理，结果形式如下：\n",
    "- 输入：`<SOS> Life is short, I use python.`\n",
    "- 输出：`Life is short, I use python. <EOS>`\n",
    "\n",
    "使用这样的文本对来训练模型\n",
    "\n",
    "参考：[Language modeling tutorial in torchtext ](http://mlexplained.com/2018/02/15/language-modeling-tutorial-in-torchtext-practical-torchtext-part-2/) 与 [github](https://github.com/keitakurita/practical-torchtext/blob/master/Lesson%202%20torchtext%20for%20language%20modeling.ipynb)"
   ]
  },
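  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input/output pair above can be built in a couple of lines of plain Python (an illustration only, not part of the preprocessing pipeline below):\n",
    "```python\n",
    "# build one shifted training pair from a token list\n",
    "tokens = ['life', 'is', 'short', 'i', 'use', 'python']\n",
    "inp = ['<sos>'] + tokens        # input starts with the start marker\n",
    "out = tokens + ['<eos>']        # target ends with the end marker\n",
    "print(list(zip(inp, out))[:2])  # [('<sos>', 'life'), ('life', 'is')]\n",
    "```"
   ]
  },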
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import spacy\n",
    "import torch\n",
    "import torchtext\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32              # 批次大小\n",
    "seq_len = 25                # 句子长度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 定义token函数\n",
    "`<`或`>`在spacy中当作非标点符号"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from spacy.symbols import ORTH\n",
    "NLP = spacy.load('en')                                          # 载入英语模型\n",
    "NLP.tokenizer.add_special_case('<eos>', [{ORTH: '<eos>'}])      # 终止符，表示文本中的特殊符号\n",
    "NLP.tokenizer.add_special_case('<sos>', [{ORTH: '<sos>'}])      # 起始符\n",
    "NLP.tokenizer.add_special_case('<unk>', [{ORTH: '<unk>'}])      # 用于标记不在词典范围内的词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenizer(text):         \n",
    "    \"\"\"\n",
    "        func: 数据清洗及预处理，并返回token标记(字母与数字)\n",
    "        text：传入需要处理的文本字符串(str)\n",
    "    \"\"\"\n",
    "    text = re.sub(r'[\\*\\\"“”\\n\\.\\+\\-\\/\\=\\(\\)\\!;\\\\]', \" \", str(text))   # 滤除无用的字符串\n",
    "    text = re.sub(r'\\s+', ' ', text)               # 将多个空格合并为一个空格\n",
    "    text = re.sub(r'\\!+', '!', text)               # 将多个 ‘!’ 合并为一个\n",
    "    text = re.sub(r'\\,+', ',', text)               # 同上\n",
    "    text = re.sub(r'\\?+', '?', text)               # 同上\n",
    "    \n",
    "    # 仅返回字符和数字，滤除标点符号\n",
    "    return [token.text for token in NLP.tokenizer(text) if not (token.is_space or token.is_punct)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 定义Field类\n",
    "声明数据预处理的pipeline\n",
    "\n",
    "**注意：这里的batch_size=True设置没起到作用???**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "sos_token = '<sos>'\n",
    "eos_token = '<eos>'\n",
    "unk_token = '<unk>'    # 数据集中有<unk>，所以该声明可以不使用\n",
    "pad_token = '<pad>'    # 文本数据是大段的，与翻译等其他数据集不同，所以不需要padding ???"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.data import Field\n",
    "# 将字符串小写，使用tokenizer函数处理数据\n",
    "TEXT = Field(sequential=True, lower=True, tokenize=tokenizer, use_vocab=True, batch_first=True,\n",
    "            init_token=sos_token, eos_token=eos_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 构建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_root = os.path.join('../..', 'data','wikitext_tiny')     # 设置数据的根目录\n",
    "data_list = ['train', 'val', 'test']                          # 设置train,val,test路径\n",
    "train_txt, val_txt, test_txt = map(lambda x:os.path.join(data_root, 'wiki_'+x+'.tokens'), data_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "进过`LanguageModelingDataset`处理后，得到数据集为`{'text':[words,...]}`字典，只有一个'text'的keys，values是数据集的所有文本，在该方法处理后会自动的在每一行后面添加一个`<eos>`符号表示结束"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 1.79 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from torchtext.datasets  import LanguageModelingDataset\n",
    "# 创建数据集\n",
    "train_dataset, val_dataset, test_dataset = LanguageModelingDataset.splits(   # 同时处理多个数据集\n",
    "    path = data_root,       \n",
    "    train = 'wiki_train.tokens', validation = 'wiki_val.tokens', test = 'wiki_test.tokens',\n",
    "    text_field = TEXT)   # 设置newline_eos=False每行不会补<eos>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['homarus', 'gammarus', 'known', 'as', 'the', 'european', 'lobster']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_dataset[0].__dict__['text'])      # 训练集共50396个单词\n",
    "train_dataset[0].__dict__['text'][5:12]     # 显示训练集中的数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 使用训练集建立字典(词汇表)\n",
    "使用GloVe的200维度预训练词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 454 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 建立字典并使用预训练的词向量，词向量与程序在同一个路径下\n",
    "TEXT.build_vocab(train_dataset,  vectors=\"glove.6B.200d\")    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.字典的大小: 6421\n",
      "2.字典中的字符: ['<unk>', '<pad>', '<sos>', '<eos>', 'the', 'of', 'in']\n",
      "3.出现频率高的词: [('the', 3838), ('<unk>', 3354), ('of', 1514), ('in', 1445), ('and', 1421)]\n"
     ]
    }
   ],
   "source": [
    "print('1. Vocab size:', len(TEXT.vocab.itos))\n",
    "print('2. First tokens:', TEXT.vocab.itos[:7])                 # first seven entries of the vocabulary\n",
    "print('3. Most frequent words:', TEXT.vocab.freqs.most_common(5))   # five most frequent words"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 创建迭代器\n",
    "BPTTIterator专门为语言建模构建的迭代器，它会生成一个timestep延迟的输入序列，属性text可以得到输入，target属性得到输出，结果类似如下：\n",
    "\n",
    "- 输入text：`<SOS> Life is short, I use python.`\n",
    "- 输出target：`Life is short, I use python. <EOS>`\n",
    "\n",
    "但使用BPTTIterator得到的结果与上述稍有不同，在<SOS>与<EOS>位置仍然是一个单词，而不是这种特殊符号"
   ]
  },
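  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal plain-Python sketch (an illustration of the idea, not torchtext's actual implementation), the chunking amounts to this: the target is the token stream shifted one position ahead of the input.\n",
    "```python\n",
    "stream = ['the', 'cat', 'sat', 'on', 'the', 'mat', '<eos>']\n",
    "bptt_len = 3\n",
    "pairs = []\n",
    "for i in range(0, len(stream) - 1, bptt_len):\n",
    "    text = stream[i:i + bptt_len]            # input chunk\n",
    "    target = stream[i + 1:i + 1 + bptt_len]  # same chunk, one step later\n",
    "    pairs.append((text, target))\n",
    "print(pairs[0])  # (['the', 'cat', 'sat'], ['cat', 'sat', 'on'])\n",
    "```"
   ]
  },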
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext.data import BPTTIterator, Iterator\n",
    "\n",
    "train_iter, val_iter, test_iter = BPTTIterator.splits((train_dataset, val_dataset, test_dataset),\n",
    "    batch_size=batch_size, bptt_len=seq_len, device=-1, repeat=False)  # -1表示使用cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.数据维度: torch.Size([25, 32])\n",
      "2.输入: tensor([   3,  411,  198,    3,    3,  411,  198,  199,   13,    4,  932,  450,\n",
      "          62,  350,  450,   12,    9,  203,    5,    0,  450,   20,    4,  408,\n",
      "        1281])\n",
      "3.输出: tensor([ 411,  198,    3,    3,  411,  198,  199,   13,    4,  932,  450,   62,\n",
      "         350,  450,   12,    9,  203,    5,    0,  450,   20,    4,  408, 1281,\n",
      "         976])\n",
      "4.查看字典键值: dict_keys(['batch_size', 'dataset', 'train', 'fields', 'text', 'target'])\n"
     ]
    }
   ],
   "source": [
    "it = iter(train_iter)                    # 生成的批数据维度顺序是 (seq,batch) \n",
    "a = next(it)\n",
    "print('1.数据维度:', a.text.shape)        # (seq_len, batch_size)\n",
    "print('2.输入:', a.text[:,0].squeeze())        # 输入比目标序列延迟一个step，但没有使用<SOS><EOS>标志???\n",
    "print('3.输出:', a.target[:,0].squeeze())  \n",
    "print('4.查看字典键值:', vars(a).keys())   # vars函数查看字典中的keys和values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.输入: of 60 cm 24 in and a mass of 6\n",
      "2.输出: 60 cm 24 in and a mass of 6 kilograms\n"
     ]
    }
   ],
   "source": [
    "a = next(it)\n",
    "def idx2sent(seq_tensor, show_len=10):\n",
    "    \"\"\"\n",
    "    func：将seq_tensor转换为一个句子\n",
    "    seq_tensor：需要转换为句子的tensor，batch_size必须为1 \n",
    "    show_len：显示的句子长度\n",
    "    \"\"\"\n",
    "    seq_flatten = seq_tensor.flatten()      # 变为一维的向量\n",
    "    seq_list = [TEXT.vocab.itos[idx] for idx in seq_flatten]\n",
    "    seq_str = ' '.join(seq_list[:show_len])\n",
    "    return seq_str\n",
    "text = idx2sent(a.text[:,0])\n",
    "target = idx2sent(a.target[:,0])\n",
    "print('1.输入:', text)\n",
    "print('2.输出:', target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**总结**\n",
    "\n",
    "上述数据预处理的方式是按照普通的数据集处理流程来的，但torchtext中已经提供了几个常用的数据集，可以直接导入，(其实这些类中的实现方式与上述方法相同)方法如下：\n",
    "```python\n",
    "from torchtext.data import Field\n",
    "from torchtext.datasets import WikiText2\n",
    "# 声明数据预处理方式\n",
    "TEXT = Field(sequential=True, lower=True, tokenize=custom_tokenizer, use_vocab=True)\n",
    "# 创建数据集\n",
    "train, val, test = WikiText2.splits(TEXT)\n",
    "```\n",
    "语言建模类中提供了`WikiText2`,`WikiText103`,`PennTreebank`3种数据集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 定义语言模型\n",
    "1. 需要注意的是lstm接收隐藏状态和细胞状态，并且是以(hidden_state, cell_state)的形式传入的，这里将两个状态写到一个`init_hidden`函数中，在模型中定义时默认是使用cpu，为了使用GPU则`model.to(device)`指定模型使用的设备，初始的状态同样也要指定设备，由于`init_hidden`函数返回的是`(hidden_state, cell_state)`元组，所以在外部指定设备时需要使用`GPU = lambda x:x.to(device)`和`hidden_0 = tuple(map(GPU, hidden_0))`将隐藏状态和细胞状态分别放到GPU上\n",
    "2. reset_history函数起到的作用是将当前的隐藏状态数据取出然后用于下一个时刻的隐藏状态输入，其实LSTM自己会这样处理，但lstm默认会跟踪整个数据集的隐藏状态，并进行反向传播，这会消耗大量内存，这样处理相当于是在时间上截断了传播"
   ]
  },
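  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The detaching in point 2 can be sketched on toy tensors (hypothetical shapes, not tied to the model below): `.detach()`, like `.data`, keeps the values but drops the autograd history, so backpropagation stops at the batch boundary.\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# a hypothetical (n_layers, batch, hidden) LSTM state with gradient tracking\n",
    "hidden = tuple(torch.zeros(2, 32, 512, requires_grad=True) for _ in range(2))\n",
    "detached = tuple(h.detach() for h in hidden)   # same values, no graph\n",
    "print(detached[0].requires_grad)  # False\n",
    "```"
   ]
  },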
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LanguageModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size, hidden_size,\n",
    "                 batch_size, n_layers=2, dropout=0.5):\n",
    "        \"\"\"\n",
    "        function: 语言模型的初始化\n",
    "        vocab_size：字典的大小\n",
    "        embed_size：词嵌入的维度大小\n",
    "        hidden_size：RNN中隐藏状态的维度\n",
    "        batch_size：批次数据的大小\n",
    "        n_layers：  RNN的层数，默认是1\n",
    "        dropout：词嵌入后进行的dropout正则化，默认为0.5\n",
    "        \"\"\"\n",
    "        super(LanguageModel, self).__init__()                 # 继承类的初始化\n",
    "        self.vocab_size = vocab_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_layers = n_layers\n",
    "        self.batch_size = batch_size                          # 批次大小\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.embed = nn.Embedding(vocab_size, embed_size)     # 嵌入为embed_size维\n",
    "        # LSTM的dropout在多层中使用，除了输出层其他层都使用dropout\n",
    "        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=n_layers) #, dropout=0)\n",
    "        self.predictor = nn.Linear(hidden_size, vocab_size)   # 输出维度为词典大小\n",
    "        \n",
    "        self.init_weights()                                   # 初始化词嵌入的权重\n",
    "        # 输入的文本是分批次的连续语料，所以需要保持不同批次间的RNN隐藏状态\n",
    "        self.hidden_state = self.init_hidden()      # 保留上个批次的隐藏状态\n",
    "        \n",
    "    def init_weights(self):\n",
    "        initrange = 0.1\n",
    "        self.embed.weight.data.uniform_(-initrange, initrange)\n",
    "        self.predictor.weight.data.uniform_(-initrange, initrange)\n",
    "        self.predictor.bias.data.zero_()\n",
    "        \n",
    "    def init_hidden(self):\n",
    "        # 初始化隐藏状态和细胞状态\n",
    "        hidden_state = torch.zeros([self.n_layers, self.batch_size, self.hidden_size])\n",
    "        cell_state = torch.zeros_like(hidden_state)\n",
    "        # lstm接收参数是元组的形式，所以这里返回元组\n",
    "        return (hidden_state, cell_state)\n",
    "    \n",
    "    def reset_history(self):\n",
    "        # 将上个批次的状态数据送入下个批次使用，只有返回的是元组!!!\n",
    "        return tuple(state.data for state in self.hidden_state)\n",
    "        \n",
    "    def forward(self, x, hidden_state):\n",
    "        \"\"\"\n",
    "        x：输入的tensor (seq, batch)\n",
    "        hidden_state：lstm的起始隐藏状态 \n",
    "        \"\"\"\n",
    "        embedded = self.drop(self.embed(x))                     # (seq, batch, embed_size)\n",
    "        # 注意这里的lstm将当前时刻输出状态会送入下一个时刻继续使用\n",
    "        lstm_out, self.hidden_state = self.lstm(embedded, hidden_state) # (seq,batch,hidden)\n",
    "        seq_len, batch = lstm_out.shape[:2]                     # 获取lstm输出的维度\n",
    "        lstm_flatten = lstm_out.view(-1, self.hidden_size)      # (seq*batch, hidden_size)\n",
    "        predictor = self.predictor(lstm_flatten)                # (seq*batch, vocab_size)\n",
    "        \n",
    "        return predictor.view(seq_len, batch, self.vocab_size)  # 恢复维度 (seq,batch,vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 3, 6421])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 测试语言模型\n",
    "test_in = torch.randint(0,6421, [5,3], dtype=torch.long)\n",
    "test_in = test_in.to(device)\n",
    "test_model = LanguageModel(vocab_size=6421, embed_size=200, hidden_size=128,batch_size=3)\n",
    "test_model.to(device)\n",
    "hidden_0 = test_model.init_hidden()\n",
    "GPU = lambda x:x.to(device)\n",
    "hidden_0 = tuple(map(GPU, hidden_0))    # hidden_0返回是(hidden,cell)形式的元组，所以这样处理!!!\n",
    "out = test_model(test_in, hidden_0)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 训练模型"
   ]
  },
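  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One aside before training (not part of the original notebook): language models are usually summarized by perplexity, the exponential of the average per-token cross-entropy, so the validation losses printed below translate directly into perplexities.\n",
    "```python\n",
    "import math\n",
    "\n",
    "avg_loss = 4.02                  # e.g. a validation cross-entropy per token\n",
    "perplexity = math.exp(avg_loss)  # perplexity = e ** cross_entropy\n",
    "print(round(perplexity, 1))      # 55.7\n",
    "```"
   ]
  },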
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.字典大小和嵌入的维度大小: torch.Size([6421, 200])\n"
     ]
    }
   ],
   "source": [
    "embedded_matrix = TEXT.vocab.vectors     # 已载入的GloVe词嵌入权重\n",
    "print('1.字典大小和嵌入的维度大小:', embedded_matrix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LanguageModel(\n",
       "  (drop): Dropout(p=0.5)\n",
       "  (embed): Embedding(6421, 200)\n",
       "  (lstm): LSTM(200, 512, num_layers=2)\n",
       "  (predictor): Linear(in_features=512, out_features=6421, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_size, embedded_size = embedded_matrix.shape\n",
    "# 定义语言模型\n",
    "languagemodel = LanguageModel(vocab_size, embedded_size, hidden_size=512, batch_size=batch_size)\n",
    "# 使用GloVe的词嵌入权重来初始化语言模型的词嵌入权重!!!\n",
    "languagemodel.embed.weight.data.copy_(embedded_matrix)\n",
    "languagemodel.to(device)          # 将模型放到GPU上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50396"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 定义损失函数及优化方法\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(languagemodel.parameters(), lr=1e-2)\n",
    "train_num = len(train_dataset[0].__dict__['text'])           # 训练集单词数目\n",
    "# train_num = len(test_dataset[0].__dict__['text'])\n",
    "val_num = len(val_dataset[0].__dict__['text'])               # 验证集单词的数目\n",
    "train_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchtext\\data\\field.py:322: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  return Variable(arr, volatile=not train)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1   Train Loss: 2.6147   Train ACC: 0.4057 Val Loss: 4.0217  Val Acc: 0.0499\n",
      "Epoch: 2   Train Loss: 1.7497   Train ACC: 0.5314 Val Loss: 4.2087  Val Acc: 0.0510\n",
      "Epoch: 3   Train Loss: 1.3635   Train ACC: 0.6201 Val Loss: 4.3051  Val Acc: 0.0522\n",
      "Epoch: 4   Train Loss: 1.1696   Train ACC: 0.6640 Val Loss: 4.4001  Val Acc: 0.0547\n",
      "Epoch: 5   Train Loss: 1.0392   Train ACC: 0.6957 Val Loss: 4.4664  Val Acc: 0.0571\n",
      "Epoch: 6   Train Loss: 0.9612   Train ACC: 0.7161 Val Loss: 4.5015  Val Acc: 0.0566\n",
      "Epoch: 7   Train Loss: 0.8826   Train ACC: 0.7397 Val Loss: 4.5693  Val Acc: 0.0573\n",
      "Epoch: 8   Train Loss: 0.8598   Train ACC: 0.7440 Val Loss: 4.6192  Val Acc: 0.0573\n",
      "Epoch: 9   Train Loss: 0.7921   Train ACC: 0.7646 Val Loss: 4.6598  Val Acc: 0.0581\n",
      "Epoch: 10  Train Loss: 0.7518   Train ACC: 0.7744 Val Loss: 4.7459  Val Acc: 0.0561\n",
      "Epoch: 11  Train Loss: 0.7139   Train ACC: 0.7850 Val Loss: 4.7548  Val Acc: 0.0542\n",
      "Epoch: 12  Train Loss: 0.6937   Train ACC: 0.7897 Val Loss: 4.8028  Val Acc: 0.0558\n",
      "Epoch: 13  Train Loss: 0.6661   Train ACC: 0.7983 Val Loss: 4.8304  Val Acc: 0.0579\n",
      "Epoch: 14  Train Loss: 0.6327   Train ACC: 0.8069 Val Loss: 4.8711  Val Acc: 0.0534\n",
      "Epoch: 15  Train Loss: 0.6240   Train ACC: 0.8105 Val Loss: 4.9277  Val Acc: 0.0540\n",
      "Epoch: 16  Train Loss: 0.6012   Train ACC: 0.8154 Val Loss: 4.9891  Val Acc: 0.0574\n",
      "Epoch: 17  Train Loss: 0.5880   Train ACC: 0.8210 Val Loss: 4.9923  Val Acc: 0.0557\n",
      "Epoch: 18  Train Loss: 0.5628   Train ACC: 0.8286 Val Loss: 5.0336  Val Acc: 0.0549\n",
      "Epoch: 19  Train Loss: 0.5657   Train ACC: 0.8274 Val Loss: 5.0509  Val Acc: 0.0551\n",
      "Epoch: 20  Train Loss: 0.5424   Train ACC: 0.8335 Val Loss: 5.0868  Val Acc: 0.0571\n",
      "Epoch: 21  Train Loss: 0.5307   Train ACC: 0.8373 Val Loss: 5.1252  Val Acc: 0.0551\n",
      "Epoch: 22  Train Loss: 0.5313   Train ACC: 0.8363 Val Loss: 5.1721  Val Acc: 0.0542\n",
      "Epoch: 23  Train Loss: 0.5256   Train ACC: 0.8410 Val Loss: 5.1699  Val Acc: 0.0603\n",
      "Epoch: 24  Train Loss: 0.5075   Train ACC: 0.8446 Val Loss: 5.2003  Val Acc: 0.0550\n",
      "Epoch: 25  Train Loss: 0.4966   Train ACC: 0.8480 Val Loss: 5.2421  Val Acc: 0.0560\n",
      "Epoch: 26  Train Loss: 0.4898   Train ACC: 0.8498 Val Loss: 5.2275  Val Acc: 0.0523\n",
      "Epoch: 27  Train Loss: 0.4847   Train ACC: 0.8503 Val Loss: 5.2636  Val Acc: 0.0516\n",
      "Epoch: 28  Train Loss: 0.4833   Train ACC: 0.8514 Val Loss: 5.3161  Val Acc: 0.0540\n",
      "Epoch: 29  Train Loss: 0.4726   Train ACC: 0.8537 Val Loss: 5.2913  Val Acc: 0.0527\n",
      "Epoch: 30  Train Loss: 0.4620   Train ACC: 0.8577 Val Loss: 5.3242  Val Acc: 0.0521\n",
      "Epoch: 31  Train Loss: 0.4674   Train ACC: 0.8565 Val Loss: 5.3325  Val Acc: 0.0522\n",
      "Epoch: 32  Train Loss: 0.4515   Train ACC: 0.8611 Val Loss: 5.3853  Val Acc: 0.0534\n",
      "Epoch: 33  Train Loss: 0.4460   Train ACC: 0.8616 Val Loss: 5.3866  Val Acc: 0.0532\n",
      "Epoch: 34  Train Loss: 0.4453   Train ACC: 0.8623 Val Loss: 5.4125  Val Acc: 0.0550\n",
      "Epoch: 35  Train Loss: 0.4392   Train ACC: 0.8660 Val Loss: 5.4360  Val Acc: 0.0545\n",
      "Epoch: 36  Train Loss: 0.4325   Train ACC: 0.8675 Val Loss: 5.4785  Val Acc: 0.0529\n",
      "Epoch: 37  Train Loss: 0.4266   Train ACC: 0.8675 Val Loss: 5.4571  Val Acc: 0.0540\n",
      "Epoch: 38  Train Loss: 0.4222   Train ACC: 0.8682 Val Loss: 5.4832  Val Acc: 0.0516\n",
      "Epoch: 39  Train Loss: 0.4205   Train ACC: 0.8700 Val Loss: 5.5237  Val Acc: 0.0535\n",
      "Epoch: 40  Train Loss: 0.4288   Train ACC: 0.8673 Val Loss: 5.5478  Val Acc: 0.0554\n",
      "Epoch: 41  Train Loss: 0.4263   Train ACC: 0.8680 Val Loss: 5.5504  Val Acc: 0.0519\n",
      "Epoch: 42  Train Loss: 0.4312   Train ACC: 0.8664 Val Loss: 5.5967  Val Acc: 0.0515\n",
      "Epoch: 43  Train Loss: 0.4217   Train ACC: 0.8688 Val Loss: 5.5797  Val Acc: 0.0538\n",
      "Epoch: 44  Train Loss: 0.4194   Train ACC: 0.8703 Val Loss: 5.6711  Val Acc: 0.0531\n",
      "Epoch: 45  Train Loss: 0.4141   Train ACC: 0.8712 Val Loss: 5.6723  Val Acc: 0.0490\n",
      "Epoch: 46  Train Loss: 0.4188   Train ACC: 0.8686 Val Loss: 5.6705  Val Acc: 0.0517\n",
      "Epoch: 47  Train Loss: 0.4296   Train ACC: 0.8651 Val Loss: 5.7297  Val Acc: 0.0529\n",
      "Epoch: 48  Train Loss: 0.4268   Train ACC: 0.8671 Val Loss: 5.7220  Val Acc: 0.0494\n",
      "Epoch: 49  Train Loss: 0.4274   Train ACC: 0.8662 Val Loss: 5.7057  Val Acc: 0.0518\n",
      "Epoch: 50  Train Loss: 0.4087   Train ACC: 0.8728 Val Loss: 5.7415  Val Acc: 0.0537\n"
     ]
    }
   ],
   "source": [
    "def train(model, train_loader, val_loader, epochs=1):\n",
    "    \"\"\"\n",
    "    func：训练模型\n",
    "    train_data：训练数据集\n",
    "    val_data：验证集\n",
    "    epochs：训练epochs\n",
    "    \"\"\"\n",
    "    pltLoss = []\n",
    "    running_loss = 0\n",
    "    running_acc = 0\n",
    "    GPU = lambda x:x.to(device)                           # 优先使用GPU\n",
    "    model = model.to(device)\n",
    "    hidden_state = model.init_hidden()                    # 初始化隐藏状态!!!\n",
    "    hidden_state = tuple(map(GPU, hidden_state))\n",
    "    for epoch in range(1, epochs+1): \n",
    "        model.train()                                     # 训练模式\n",
    "#         model.zero_grad()\n",
    "#         hidden_state = model.reset_history()\n",
    "        for data in train_loader:\n",
    "            x, y = data.text, data.target                 # (seq,batch)\n",
    "            x, y = map(GPU, [x, y])\n",
    "            preds = model(x, hidden_state)\n",
    "            preds_flatten = preds.view(-1, vocab_size)    # 将输出变为2维向量\n",
    "            y_flatten = y.view(-1)                        # 将标签也变为2维向量\n",
    "#             print(preds_flatten.shape, y_flatten.shape)\n",
    "            loss = criterion(preds_flatten, y_flatten)\n",
    "            print_loss = loss.item()                      # 获取loss数据   \n",
    "#             print(print_loss)\n",
    "            optimizer.zero_grad()                         # 更新参数\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "#           \n",
    "            _, out = torch.max(preds_flatten, dim=1)            \n",
    "            running_loss += print_loss * preds_flatten.shape[0]\n",
    "#             print(out.data[:20])\n",
    "            running_acc += torch.sum(out == y_flatten).item()\n",
    "        hidden_state = model.reset_history()   # !!!将当前批次的隐藏状态数据放到下一个批次\n",
    "        epoch_loss = running_loss / train_num                   # 计算epoch的损失\n",
    "        epoch_acc = running_acc / train_num\n",
    "        running_loss = 0                                        # 切记epoch完成后清零\n",
    "        running_acc = 0\n",
    "        pltLoss.append(epoch_loss)\n",
    "        \n",
    "        # 计算验证集上的误差\n",
    "        val_loss = 0\n",
    "        val_acc = 0\n",
    "        model.eval()                                            # 模型进入测试模型\n",
    "        with torch.no_grad():\n",
    "            hidden_state = model.reset_history()\n",
    "            for data_val in val_loader:\n",
    "                x_val, y_val = data_val.text, data_val.target   # (seq,batch)\n",
    "                x_val, y_val = map(GPU, [x_val, y_val])         # 数据放到GPU上\n",
    "                preds = model(x_val, hidden_state)  \n",
    "                preds_flatten = preds.view(-1, vocab_size)\n",
    "                y_flatten = y_val.view(-1)\n",
    "                loss = criterion(preds_flatten, y_flatten)\n",
    "                val_loss += loss.item() * preds_flatten.shape[0]\n",
    "                _, out_val = torch.max(preds_flatten, dim=1)\n",
    "                val_acc += torch.sum(out_val == y_flatten).item()\n",
    "                \n",
    "            epoch_val_loss = val_loss / val_num\n",
    "            epoch_val_acc = val_acc / val_num\n",
    "        print('Epoch: {:<3} Train Loss: {:<8.4f} Train ACC: {:<5.4f} Val Loss: {:<7.4f} Val Acc: {:5.4f}'.format(\n",
    "            epoch, epoch_loss, epoch_acc, epoch_val_loss, epoch_val_acc))\n",
    "    return model, pltLoss          \n",
    "\n",
    "train_model, pltloss = train(languagemodel, train_iter, test_iter, epochs=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 存在问题\n",
    "当不使用dropout时，准确率会在0.2125上持续很多epoch，然后才继续上升，使用dropout会好一点，最大的问题是在关闭dropout时测试集上无法实现过拟合，准确率只能达到98%？？？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x2044e488b00>]"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAHRVJREFUeJzt3XmUnHW95/H3t6qrurqqel/Snc5OSMIWthDCBRW4LmwXdEZG3GG85I4jM+rBc+eOZ+Y64z3OGY+OXK84akQGnRGEGQS5ighoEFEDdNhC9oUlSafT3el9re03f1R16CTd6U5S3U/qqc/rnDpVz1NPnvo+h+JTv/49v+f3mHMOERHxl4DXBYiISP4p3EVEfEjhLiLiQwp3EREfUriLiPiQwl1ExIcU7iIiPqRwFxHxIYW7iIgPlXj1wXV1dW7RokVefbyISEHauHFjp3OufqrtPAv3RYsW0dLS4tXHi4gUJDN7azrbqVtGRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8SOEuIuJDCncRER8quHDf1tbHN3+zne7BhNeliIictgou3N/sHOLu9bvY3zPsdSkiIqetggv3mlgYgO4htdxFRCZTsOHepW4ZEZFJKdxFRHyo4MK9sixEwBTuIiLHU3DhHgwYVdGwwl1E5DgKLtwh2zWjcBcRmdyU4W5m881svZltNbPNZvb5Cba50sx6zeyV3OPvZ6bcrBq13EVEjms6N+tIAXc6514ys3Jgo5k95ZzbctR2f3DO3ZD/Eo9VEwuzu2NgNj5KRKQgTdlyd84dcM69lHvdD2wFmme6sOOpjoU1zl1E5DhOqM/dzBYBFwLPT/D2ZWb2qpn92szOyUNtk6qNhekeSpLJuJn8GBGRgjXtcDezOPAw8AXnXN9Rb78ELHTOnQ98B3h0kn2sNbMWM2vp6Og42ZqpjoVJZxx9I8mT3oeIiJ9NK9zNLEQ22H/qnPv50e875/qccwO5148DITOrm2C7dc65Vc65VfX1U968e1K1uQuZDumkqojIhKYzWsaAHwFbnXPfmmSbxtx2mNnq3H4P5bPQ8arH5pdRuIuITGg6o2UuBz4JbDKzV3LrvgwsAHDOfR/4MPBZM0sBw8AtzrkZ6xBXy11E5PimDHfn3HOATbHN3cDd+SpqKmq5i4gcX0FeoaqWu4jI8RVkuEdCQaLhoFruIiKTKMhwB6jWFAQiIpMq2HCvjYfp0lWqIiITKthwV8tdRGRyBRvutZr2V0RkUgUb7tUKdxGRSRVsuNfEwgwl0owk016XIiJy2inocAfdS1VEZCIKdxERH1K4i4j4kMJdRMSHCjfcowp3EZHJFGy4V5aFCJjCXURkIgUb7oGAZa9S1RQEIiLHKNhwh2y/e9eAwl1E5GgFHe7VMbXcRUQmUtDhrvllREQmVtDhXh0L64YdIiITKOhwr42F6R5KkMnM2L24RUQKUkGHe3U0TMZB73DS61JERE4rBR3utXHdKFtEZCIFHe7VuatUuzViRkTkCAUd7mPzyxzSWHcRkSMUdLiPdcuo5S4icqSCDvdqTR4mIjKhgg73SChILBxUuIuIHKWgwx10o2wRkYkUfLhrCgIRkWMVfLir5S4icqyCD/cahbuIyDEKP9yjCncRkaMVfrjHwwwn0wwn0l6XIiJy2ij8cB8b664LmUREDiv8cM9NQaDb7YmIvMM/4a6Wu4jIYVOGu5nNN7P1ZrbVzDab2ecn2MbM7J/MbJeZvWZmF81Mucc6HO6Do7P1kSIip72SaWyTAu50zr1kZuXARjN7yjm3Zdw21wJn5h6XAt/LPc+4d8JdN+wQERkzZcvdOXfAOfdS7nU/sBVoPmqzm4CfuKwNQJWZNeW92glUREIEA6aWu4jIOCfU525mi4ALgeePeqsZ2DtueR/H/gBgZmvNrMXMWjo6Ok6s0kkEAkZ1NKSWu4jIONMOdzOLAw8DX3DO
9R399gT/5Ji7Vjvn1jnnVjnnVtXX159YpceRvUpVLXcRkTHTCnczC5EN9p86534+wSb7gPnjlucBrade3vRUR8N0q+UuInLYdEbLGPAjYKtz7luTbPYY8KncqJk1QK9z7kAe6zyu2niYQ2q5i4gcNp3RMpcDnwQ2mdkruXVfBhYAOOe+DzwOXAfsAoaA2/Jf6uSqo2G6h9RyFxEZM2W4O+eeY+I+9fHbOOBz+SrqRNXGwnQPJUhnHMHAcUsVESkKBX+FKmTndHcOeofVehcRAZ+Eu65SFRE5ki/CvTZWCugqVRGRMb4I9+pYCFDLXURkjC/CXS13EZEj+SLc1XIXETmSL8K9tCRIvLRELXcRkRxfhDtkW+9quYuIZPkm3GtipXTpKlUREcBP4R5Vy11EZIx/wj1WqpkhRURyfBTuIc0MKSKS46NwL2UkmWE4kfa6FBERz/ko3LNj3dV6FxHxVbiPXaWa8LgSERHv+Sjcx65SVbiLiPgo3LMt90MDCncREd+Ee3NVGSUBY3fHgNeliIh4zjfhHi4JsLQhzra2fq9LERHxnG/CHWBFYznbDvR5XYaIiOd8Fe7LGyto7R2hV3PMiEiR81W4r2gqB2Bbm1rvIlLcfBXuZzVWAKjfXUSKnq/CfU5FKVXRkFruIlL0fBXuZsaKxnK2HlDLXUSKm6/CHWBFYwXb2/rJZJzXpYiIeMZ34X5WUznDyTRvdw15XYqIiGd8F+4rDp9UVb+7iBQv34X7sjnlmKF+dxEpar4L97JwkMW1MbXcRaSo+S7cIXsxk8a6i0gx82e4N1bw1qEhBkdTXpciIuIJn4Z7dhqC7QfVeheR4uTLcD+rKTtiZru6ZkSkSPky3JuryoiXlmj6XxEpWlOGu5nda2btZvb6JO9faWa9ZvZK7vH3+S/zxAQCxvLGcraq5S4iRWo6Lff7gGum2OYPzrkLco+vnnpZp27sxh3OaRoCESk+U4a7c+5ZoGsWasmrFY3l9I2kONA74nUpIiKzLl997peZ2atm9mszOydP+zwlK5o0DYGIFK98hPtLwELn3PnAd4BHJ9vQzNaaWYuZtXR0dOThoye3PDccUtMQiEgxOuVwd871OecGcq8fB0JmVjfJtuucc6ucc6vq6+tP9aOPqyISormqTFeqikhROuVwN7NGM7Pc69W5fR461f3mw1lN5RoOKSJFqWSqDczsAeBKoM7M9gFfAUIAzrnvAx8GPmtmKWAYuMWdJkNUVjRWsH57ByPJNJFQ0OtyRERmzZTh7pz76BTv3w3cnbeK8mhFUznpjGNX+wDnNld6XY6IyKzx5RWqY965cYf63UWkuPg63BfVRiktCajfXUSKjq/DvSQYYNkcze0uIsXH1+EOuWkIdCGTiBQZ/4d7UwWdAwk6+ke9LkVEZNb4PtzPyl2pqta7iBQT34f72DQE2zQNgYgUEd+He228lLmVEZ5/o+AmthQROWm+D3eA685r4vc72ukZSnhdiojIrCiKcP/ghc0k045fbTrgdSkiIrOiKML9nLkVLG2I84uXW70uRURkVhRFuJsZH7qwmRfe7GJf95DX5YiIzLiiCHeAG8+fC8AvXlHrXUT8r2jCfX5NlEsWVfPoy/t102wR8b2iCXeAmy5oZmf7AFs0kZiI+FxRhfv15zURChqPvrzf61JERGZUUYV7dSzMe5Y18ItXWkln1DUjIv5VVOEO8KELm2nvH2XDntPiNq8iIjOi6ML9L89qIF5awiPqmhERHyu6cI+Eglx7biNPvN7GSDLtdTkiIjOi6MIdstMRDIymeHrrQa9LERGZEUUZ7muW1DKnopRHNR2BiPhUUYZ7MGDcdEEzz2xvp3tQM0WKiP8UZbgD3HTBXFIZzRQpIv5UtOF+dlMFy+bEeahlr6YjEBHfKdpwNzNuu3wxr+3r5aktOrEqIv5StOEOcPPF81hSF+Mbv9muK1ZFxFeKOtxLggG+9IHl7Gwf0EVNIuIrRR3uANee28h5zZXc
9dQORlO6qElE/KHow93M+A/XrGB/zzA/3fC21+WIiORF0Yc7wBVn1nH50lruXr+LgdGU1+WIiJwyhXvO335gBV2DCe75wx6vSxEROWUK95zz51dx7bmN/PDZPRwaGPW6HBGRU6JwH+fO9y9nOJnmu+t3e12KiMgpUbiPs7Qhzs0Xz+f/bHiLfd1DXpcjInLSFO5H+cL7zgSDu57a6XUpIiInbcpwN7N7zazdzF6f5H0zs38ys11m9pqZXZT/MmdPU2UZt/3FIh5+aR+PvaopgUWkME2n5X4fcM1x3r8WODP3WAt879TL8tYX37eM1YtruPOhV/jT7k6vyxEROWFThrtz7lmg6zib3AT8xGVtAKrMrClfBXohEgryw0+uYnFdjL/5yUa2tfV5XZKIyAnJR597M7B33PK+3LqCVhkNcd9tq4mWBrn13hdp7Rn2uiQRkWnLR7jbBOsmnGLRzNaaWYuZtXR0dOTho2fW3Koy7rttNYOjKW79Xy/QO5z0uiQRkWnJR7jvA+aPW54HTHgm0jm3zjm3yjm3qr6+Pg8fPfPOaqrgB5+6mDc6B1n7kxZGkppcTEROf/kI98eAT+VGzawBep1zvrp33V+cUcc3bz6f59/o4s6HXiWjud9F5DRXMtUGZvYAcCVQZ2b7gK8AIQDn3PeBx4HrgF3AEHDbTBXrpZsuaOZg3wj/7fFtNFVG+E83nO11SSIik5oy3J1zH53ifQd8Lm8VncZuf9cSWntGuOe5N5hbVca/vmKx1yWJiExoynCXd5gZ//mGsznQO8w//GoLTZURrj2voEd9iohPafqBExQMGN++5UIunF/F5x98hZY3j3cJgIiINxTuJyESCnLPpy9hXlUZf/2TFnZ3DHhdkojIERTuJ6kmFua+21ZTEjA+fe8LtPePeF2SiMhhCvdTsKA2yr23XsKhgQSfua+FTt3kQ0ROEwr3U7RyXhXf/fiFbDnQx+qvPc3HfriB/73hLbXkRcRTlh3JOPtWrVrlWlpaPPnsmbDzYD+PvdrKrzYdYE/HIGZwyaIarju3kQ9e2ExVNOx1iSLiA2a20Tm3asrtFO755ZxjZ/sAj286wOObDrDj4AALa6Pcf/samqvKvC5PRAqcwv008cIbXXzmxy9SWRbigdvXML8m6nVJIlLAphvu6nOfYasX13D/X6+hfyTFR37wZ97sHPS6JBEpAgr3WXDevEoeuH0NI6kMH1n3Z42LF5EZp3CfJWfPreCB29eQzjg+8oMN7DjY73VJIuJjCvdZtLyxnJ+tvYyAwS3rNrClVbfvE5GZoXCfZUsb4jz4N5dRWhLgQ//zj9z11A6GE7oBiIjkl8LdA4vrYjzyby/nfWfP4du/3cl7v/V7fvlaK16NXBIR/1G4e6SxMsLdH7uIB9euoaIsxB33v8xH1m1gc2uv16WJiA8o3D126ZJafvnvruBrHzqXnQf7+avvPMeXH9lER7/mqRGRk6dwPw0EA8bHL13IM1+6ik9dtoiHXtzLld9Yz7ef3sngaMrr8kSkACncTyOV0RD/5cZzePKL7+bdy+q56+kdXPnNZ7j/+bdJpTNelyciBUThfhpaUh/ne5+4mIc/exkLa6J8+ZFNfOAfn+XJzW2kMzrpKiJT09wypznnHE9uOcjXf72NPZ2DVEdDvGdZPVetaOA9y+o126RIkZnu3DK6QfZpzsz4wDmNXL2igSc3H+S3Ww/yzI4OHn2llYDBRQuquWpFA3+1ci4LajUpmYhkqeVegNIZx2v7eli/rZ312zvYtD87fPJdZ9bx0dULeO9ZcwiXqMdNxI805W8ROdA7zP9t2ceDL+5lf88wdfEwH754PrdcMp9FdTGvyxORPFK4F6F0xvHsjg7uf+FtfretnXTGceXyej531VIuWVTjdXkikgcK9yLX1jvCgy/u5cd/fpOuwQSXLq7hjquXcsXSOszM6/JE5CQp3AWAoUSKB17Yy7pnd3Owb5Tz51dxx1VL+csVDQQCCnmR
QqNwlyOMptI8vHE/3//9bt7uGqKxIsLiuhjza8qYVx1lXnX2eVFdlIbyiNflisgkNBRSjlBaEuRjly7gX62axz+/1sr6bR3s6x7ime0dtB81j83KeZVce24T153XyMJanZAVKURquQsjyTStPcPs7R5mS2sfT2xu49W9PQCc3VTBdec1cs25TZxRH1N/vYjH1C0jp2Rf9xBPvN7Gr19vY+Nb3QBUR0OcPbeCc+ZWcs7cCs6ZW8HiujhB9d2LzBqFu+RNW+8IT289yOv7e9nc2sf2tn4SuYnMykJBLl9ax7+4qJmrVzQQCQU9rlbE39TnLnnTWBnhE2sWHl5OpjPsah9gc2sfr+3r4YnX23h660EqIiVcv7KJD104j1ULqzUaR8RDarnLKUtnHH/a3ckjL+3nic1tDCXSzKsuY+W8SoYTaYaTaYYTaYZyr2PhEt5/zhyuX9nE8jnl6scXOQHqlhFPDCVSPLn5II+8vJ993UNEwyWUhYKUhYNEw9nntt4RNuw5RMbBGfUxrl85lxtWNrFsTrnX5Yuc9hTuclrrHBjlidfb+NVrB9jwxiGcgyX1Mc5rrmTZnPLcI8786qi6d0TGyWu4m9k1wLeBIHCPc+6/H/X+rcA3gP25VXc75+453j4V7jKmvX+E37zexu+2tbPj4AD7e4YPvxcJBVhSF6csHCTjHBkHmYw7/DoWDlITC1MbD1MTC1MTK6U2FmZBbZSVzZWUBDU7pvhL3sLdzILADuB9wD7gReCjzrkt47a5FVjlnLtjugUq3GUy/SNJdrYPsPNgP9vbBtjdMUAqkyFglntAwAwzY3A0RddggkODCbqHEkfcqSpeWsKli2u47IxaLl9ax/I55forQApePkfLrAZ2Oef25Hb8M+AmYMtx/5XISSqPhLhoQTUXLag+oX+XyTj6RpIcGkyw7UA/f9zdyZ93H+K329oBqI2FuWRRDcvmxDmjIc7ShvjhvwpE/GY64d4M7B23vA+4dILt/qWZvZtsK/+Lzrm9E2wjMmMCAaMqGqYqGuaM+jjXr2wCYH/PMH/alQ36l97u5sktbYw18M2guaqMJfVxmioiNFSU0lBeSn15KfXlERrKS6mNhykLBTWqRwrKdMJ9om/00X05/ww84JwbNbN/A/wYuPqYHZmtBdYCLFiw4ARLFTk5zVVl3LxqPjevmg9kp1t489Agu9sH2dWe7fZ5o3OQbQf66BwYZaJ7kIdLAlRHQ1RHw1RFQ1SVhYmWBgkFApQEjVAwQCholAQDVERCLKmPcUZ9nIW1UULq9xcPTCfc9wHzxy3PA1rHb+CcOzRu8YfA1yfakXNuHbAOsn3uJ1SpSJ5EQkFWNFaworHimPfSGcehwVHa+0bp6M8+uoay/fk9g0m6hhL0DCXY1THAcCJNKpMhlXYk0tnnVCZDMv3OVzsYMBbWRFlSH+eM+hhL6mMsqY+zpC5GTSw86V8D6YyjfyRJRSSk8wRyUqYT7i8CZ5rZYrKjYW4BPjZ+AzNrcs4dyC3eCGzNa5UisyQYMBrKI6c07XHfSJI9HYPs6RhgT8cgu3PPz+7oODxtA0BFpIQl9XEW1EQZTqbpHkxkf0gGE/QMJ3EOyiMlXLSgmlULq7l4UTUXzK8iGtaF5TK1Kb8lzrmUmd0B/IbsUMh7nXObzeyrQItz7jHg35vZjUAK6AJuncGaRU5rFZEQF8yv4oL5VUesT2cc+7uH2d2ZDfs3cs8v7+0mFi6hOhrmrMYKqmPZ7p+KSIg9nYNsfKuL//FUB5D98Tm7qYIzG+LUxMJUx8LUjnsuj4QIBsDMCI6NLgpAKBigOhrWjdOLiC5iEikAvUNJXtrbzcY3u2l5q4u9XcN0DSYYTqZPaD+VZSFq42Hq4qXUx7Mni+vjpbkTyRHqy0tzJ5FLNdvnaUoTh4n4SGU0xFXLG7hqecMR64cT6cNdOV2DCfpHUrkLvLKPdCY7RHQ0naF7MEHnwCiHBhJ0DIyyra2Pjv5R+kZSx3xe
wMheHBYrPXyRWF08e4FYVSxMRaSEyrIQFWUhKiIhKstCxEtLCJcEJvxRGEmmOdA7QmvPMPu7h9nXM0xb7zCGURYOUhoKUBYKEgkFiZQESKQz9Awl6RlO0jOUyL4eShIMGAtqoiyojbJw7Lk2RlNFJK/nJlLpDG93DR3uVtvdMUD3UDI7lUZuOo1I7nW4JIDD4Q5fYAcZ5wiYsXJeJZcuqfGkK03hLlLAysJBmsNlNFeVnfQ+RpLp7MnjgbETySO094/SOZDg0MAoXYMJNrdmRxL1T/BDcLSAZbuBwiUBwrmRQocGE0dsYwb18VLMsj9QI6kMiVTmiG1KApYdmRQNU1UWYm5VhGTaseVAH09uaTvixLUZlJYEiISCh58jJUEi4SBVZaHsSKdYmJpo9sepqixEIpWhZzhJ71CC3uGxH5Ik+3uGeevQ4BH7r4uXUhcPM5rK5CbBSzGSzBxxDmUyoaBx0YJq3nVmHVecWc95zZWz8leRumVEZNpGU2l6h5P0DafoG0nmXifpG0kxMJIimc6QTGdDL5lyJNJpMg4aKyLMrcr+CM2rLmNOReSY/v90xjGays4gWhoKEgtPfm1BOuNo7Rnm7a4h3jo0RFvvMCOpDKPJNCPJDCOpNKPJDEPJNL1D2RPVPYNJ+keP/XEyy54nqYpm/wJprIhwRkOcMw6PcIpTWRaasI5U7lizV0xz+DyHGYymMrz4ZhfP7ezkDzs72XKgD8h2jd1x1VJuf/eSk/pvoInDRESOkkhlst08w0lKSwJUlYUpj5TMynDTzoFR/rirk+d2dvKuZfXceP7ck9qPwl1ExIemG+4aFyUi4kMKdxERH1K4i4j4kMJdRMSHFO4iIj6kcBcR8SGFu4iIDyncRUR8yLOLmMysA3jrJP95HdCZx3IKSbEeu467uOi4J7fQOVc/1Y48C/dTYWYt07lCy4+K9dh13MVFx33q1C0jIuJDCncRER8q1HBf53UBHirWY9dxFxcd9ykqyD53ERE5vkJtuYuIyHEUXLib2TVmtt3MdpnZ33ldz0wxs3vNrN3MXh+3rsbMnjKznbnnai9rnAlmNt/M1pvZVjPbbGafz6339bGbWcTMXjCzV3PH/V9z6xeb2fO5437QzMJe1zoTzCxoZi+b2S9zy74/bjN708w2mdkrZtaSW5e373lBhbuZBYHvAtcCZwMfNbOzva1qxtwHXHPUur8DfuucOxP4bW7Zb1LAnc65s4A1wOdy/439fuyjwNXOufOBC4BrzGwN8HXgrtxxdwOf8bDGmfR5YOu45WI57quccxeMG/6Yt+95QYU7sBrY5Zzb45xLAD8DbvK4phnhnHsW6Dpq9U3Aj3Ovfwx8cFaLmgXOuQPOuZdyr/vJ/g/fjM+P3WUN5BZDuYcDrgb+X269744bwMzmAdcD9+SWjSI47knk7XteaOHeDOwdt7wvt65YzHHOHYBsCAINHtczo8xsEXAh8DxFcOy5rolXgHbgKWA30OOcG7urs1+/7/8I/C2QyS3XUhzH7YAnzWyjma3Nrcvb97wkDwXOponuYqvhPj5kZnHgYeALzrm+bGPO35xzaeACM6sCHgHOmmiz2a1qZpnZDUC7c26jmV05tnqCTX113DmXO+dazawBeMrMtuVz54XWct8HzB+3PA9o9agWLxw0syaA3HO7x/XMCDMLkQ32nzrnfp5bXRTHDuCc6wGeIXvOocrMxhphfvy+Xw7caGZvku1mvZpsS97vx41zrjX33E72x3w1efyeF1q4vwicmTuTHgZuAR7zuKbZ9Bjw6dzrTwO/8LCWGZHrb/0RsNU5961xb/n62M2sPtdix8zKgPeSPd+wHvhwbjPfHbdz7j865+Y55xaR/f/5d865j+Pz4zazmJmVj70G3g+8Th6/5wV3EZOZXUf2lz0I3Ouc+5rHJc0IM3sAuJLsLHEHga8AjwIPAQuAt4GbnXNHn3QtaGZ2BfAHYBPv9MF+mWy/u2+P3cxWkj2BFiTb6HrI
OfdVM1tCtkVbA7wMfMI5N+pdpTMn1y3zJefcDX4/7tzxPZJbLAHud859zcxqydP3vODCXUREplZo3TIiIjINCncRER9SuIuI+JDCXUTEhxTuIiI+pHAXEfEhhbuIiA8p3EVEfOj/A3uiwHWexNQmAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x203919087b8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline \n",
    "plt.plot(pltloss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchtext\\data\\field.py:322: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  return Variable(arr, volatile=not train)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 11.9382  Test Acc: 0.1182 \n"
     ]
    }
   ],
   "source": [
    "test_num = len(test_dataset[0].__dict__['text'])\n",
    "# print(test_num)\n",
    "def test(model, test_loader):\n",
    "    \"\"\"\n",
    "    func：测试模型的性能\n",
    "    test_loader：测试用的数据集迭代器\n",
    "    \"\"\"\n",
    "    running_loss = 0\n",
    "    running_acc = 0\n",
    "    model.eval()     # 进行测试模式\n",
    "    model.cpu()\n",
    "    hidden_state = model.init_hidden()                    # 初始化隐藏状态!!!\n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            # x: (seq,batch) y: (seq,batch)\n",
    "            x, y = data.text, data.target                 # 读取测试数据\n",
    "            batch_preds = []\n",
    "            for step in range(x.shape[0]):                        # 沿着句子长度方向进行预测，这里seq_len=25\n",
    "#                 print(x[step:step+1,:].shape)\n",
    "                preds = model(x[step:step+1,:], hidden_state)     # 处理第一个词 input:(1,batch) ouput: (1,batch,vocab)\n",
    "                hidden_state = model.reset_history()              # 将hidden_state进行更新\n",
    "                batch_preds.append(preds)                         # 将每个step输出添加至列表\n",
    "            batch_preds = torch.cat(batch_preds, dim=0)           # get (seq,batch,vocab)\n",
    "            preds_flattten = batch_preds.view(-1, vocab_size)     # get (seq*batch, vocab)\n",
    "            y_flatten = y.view(-1)                                # get (seq*batch,)\n",
    "            loss = criterion(preds_flattten, y_flatten)           # 计算loss\n",
    "            \n",
    "            running_loss += loss.item() * preds_flattten.shape[0]  # 累加整个批次每个单词的损失\n",
    "            _, out = torch.max(preds_flattten, dim=1)              # 沿着第2维度进行预测，提取最大值所在的索引\n",
    "            running_acc += torch.sum(out == y_flatten).item()      # 切记加item()，否则输出为0\n",
    "        print('Test Loss: {:<8.4f} Test Acc: {:<7.4f}'.format(\n",
    "            running_loss/test_num, running_acc/test_num))\n",
    "test(train_model, test_iter)"
   ]
  },
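  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Language-model quality is often summarized as perplexity, the exponential of the average per-token cross-entropy. A minimal sketch, assuming the `Test Loss` printed above is an average per-token cross-entropy in nats:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "test_loss = 11.9382                # average per-token cross-entropy reported above\n",
    "perplexity = math.exp(test_loss)   # perplexity = exp(cross-entropy)\n",
    "print('Perplexity: {:.1f}'.format(perplexity))"
   ]
  },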
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------\n",
      "起始词: street\n",
      "1.预测: <unk> school was by during the great depression <unk> <unk>\n",
      "----------------------------------------------------------------------\n",
      "起始词: teams\n",
      "1.预测: three years later however the overall effects of <unk> were\n",
      "----------------------------------------------------------------------\n",
      "起始词: was\n",
      "1.预测: meridian county to the city to keep the national philanthropist\n",
      "----------------------------------------------------------------------\n",
      "起始词: lauderdale\n",
      "1.预测: protagonist of the amc <unk> of the area to the\n",
      "----------------------------------------------------------------------\n",
      "起始词: 1911\n",
      "1.预测: 3 mitch <unk> who was ranked to by hand in\n",
      "----------------------------------------------------------------------\n",
      "起始词: the\n",
      "1.预测: from the <unk> lost running in 2012 the following day\n",
      "----------------------------------------------------------------------\n",
      "起始词: meridian\n",
      "1.预测: and pull of the new generation of the <unk> creek\n",
      "----------------------------------------------------------------------\n",
      "起始词: edible\n",
      "1.预测: time he was given only one carry for six of\n",
      "----------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "def sample_show(model, data_loader, show_num=1):\n",
    "    \"\"\"\n",
    "    func：显示一些预测的句子\n",
    "    model：使用的语言模型\n",
    "    data_loader：需要预测的数据集\n",
    "    show_num：显示的样本数\n",
    "    \"\"\"\n",
    "    print('-'*70)\n",
    "    model.eval()     # 进行测试模式\n",
    "    model.cpu()\n",
    "    hidden_state = model.init_hidden()                    # 初始化隐藏状态!!!\n",
    "    with torch.no_grad():\n",
    "        for i, data in enumerate(data_loader):\n",
    "            sample_index = random.choice(range(batch_size))       # 从该批次随机一个数\n",
    "            # x: (seq,batch) y: (seq,batch)\n",
    "            x, y = data.text, data.target                         # 读取测试数据\n",
    "            batch_sample = []\n",
    "            sample_input = x[:1, :]                               # (1,batch)\n",
    "            print('起始词:', idx2sent(sample_input[:, sample_index]))\n",
    "            for step in range(x.shape[0]):                        # 沿着句子长度方向进行预测，这里seq_len=25\n",
    "#                 print(x[step:step+1,:].shape)\n",
    "                preds = model(sample_input, hidden_state)         # 处理第一个词 input:(1,batch) ouput: (1,batch,vocab)\n",
    "                hidden_state = model.reset_history()              # 将hidden_state进行更新\n",
    "\n",
    "                _, sample_input = torch.max(preds, dim=2)         # (1, batch)，sample_input用于预测下一个单词\n",
    "                batch_sample.append(sample_input)\n",
    "            batch_sample = torch.cat(batch_sample, dim=0)         # get (seq,batch)\n",
    "            \n",
    "            sample_index = random.choice(range(batch_size))\n",
    "            show = idx2sent(batch_sample[:,sample_index])\n",
    "            target = idx2sent(y[:,sample_index])\n",
    "            print('1.预测:', show)\n",
    "#             print('2.目标:', target)\n",
    "            print('-'*70)\n",
    "            if i == 7:\n",
    "                break\n",
    "sample_show(train_model, train_iter)       "
   ]
  },
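  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above feeds the argmax back in at every step (greedy decoding), which tends to produce repetitive, generic text. A common alternative is temperature sampling; a minimal sketch, assuming `preds` is the `(1, batch, vocab)` logit tensor produced by the model above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_next(preds, temperature=0.8):\n",
    "    logits = preds[0] / temperature                      # (batch, vocab); <1 sharpens, >1 flattens\n",
    "    probs = F.softmax(logits, dim=-1)                    # normalize logits to a distribution\n",
    "    return torch.multinomial(probs, num_samples=1).t()   # (1, batch), one sampled token per sequence"
   ]
  },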
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 该程序还需继续优化！！！"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
